import os.path as osp
import ipdb
from tqdm import tqdm
import argparse
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, DeepGraphInfomax
from torch_geometric.data import GraphSAINTRandomWalkSampler, NeighborSampler
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import subgraph
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.data import DataLoader
import scipy.sparse as ss
import numpy as np


class Encoder(nn.Module):
    """One-layer GraphSAGE encoder followed by a channel-wise PReLU.

    Args:
        in_channels: size of each input node feature vector.
        hidden_channels: size of the produced node embeddings; the PReLU
            holds one learnable slope per output channel.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        # Registration order (conv, then prelu) and attribute names are kept
        # so parameter ordering and state_dict keys stay the same.
        self.conv = SAGEConv(in_channels, hidden_channels)
        self.prelu = nn.PReLU(hidden_channels)

    def forward(self, x, edge_index):
        """Return ``PReLU(SAGEConv(x, edge_index))``.

        ``x`` is forwarded unchanged to ``SAGEConv``; the file's batched
        inference passes an ``(x, x_target)`` tuple here.
        """
        hidden = self.conv(x, edge_index)
        return self.prelu(hidden)


def corruption(x, edge_index):
    """Build DGI's negative sample by shuffling node feature rows.

    Args:
        x: node feature tensor; its rows are permuted.
        edge_index: graph connectivity, returned unchanged.

    Returns:
        ``(x[perm], edge_index)`` where ``perm`` is a random permutation of
        the row indices.
    """
    # Generate the permutation on x's device so no host->device copy of the
    # index tensor is needed on every training step.
    perm = torch.randperm(x.size(0), device=x.device)
    return x[perm], edge_index


def train(model, optimizer, loader, device):
    """Run one DGI training epoch over ``loader``.

    Each batch is moved to ``device`` and fed to ``model``, which yields
    positive embeddings, corrupted embeddings and a summary vector. The
    model's own ``loss`` on those outputs is back-propagated, and one
    optimizer step is taken per batch.

    Returns:
        The summed batch losses divided by ``len(loader)``.
    """
    model.train()

    batch_losses = []
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        pos_z, neg_z, summary = model(batch.x, batch.edge_index)
        batch_loss = model.loss(pos_z, neg_z, summary)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(loader)

@torch.no_grad()
def test(model, data, subgraph_loader, split_idx, device, SAVEPATH=None):
    """Embed every node with the trained model and score a linear probe.

    Args:
        model: trained DGI model; its ``test`` method is called with the
            train-split embeddings/labels and the test-split embeddings/labels.
        data: full graph; ``data.x`` is embedded and ``data.y`` supplies labels.
        subgraph_loader: loader yielding ``(batch_size, n_id, adj)`` batches,
            consumed by ``inference``.
        split_idx: mapping with at least ``'train'`` and ``'test'`` index
            tensors.
        device: device each inference batch is moved to.
        SAVEPATH: optional file path; when given, the embedding tensor is
            written there with ``torch.save`` (parent directory created if
            missing).

    Returns:
        The value returned by ``model.test`` (printed as accuracy by ``main``).
    """
    model.eval()
    z = inference(model, data.x, subgraph_loader, device)
    if SAVEPATH is not None:
        import os
        # Create the output directory first; otherwise torch.save fails only
        # after the whole (expensive) training run has already finished.
        save_dir = os.path.dirname(SAVEPATH)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(z, SAVEPATH)
    train_idx, test_idx = split_idx['train'], split_idx['test']
    acc = model.test(z[train_idx], data.y[train_idx].view(-1),
                     z[test_idx], data.y[test_idx].view(-1), max_iter=500)
    return acc

def inference(model, x_all, subgraph_loader, device):
    """Compute embeddings for all nodes in mini-batches.

    Args:
        model: DGI model, called as ``model((x, x_target), edge_index, EVAL=True)``.
        x_all: node feature matrix indexed by the ``n_id`` of each batch; each
            gathered slice is copied to ``device``.
        subgraph_loader: iterable of ``(batch_size, n_id, adj)`` where ``adj``
            unpacks to ``(edge_index, e_id, size)``; presumably a one-hop
            ``NeighborSampler`` with ``shuffle=False`` -- confirm at call site.
        device: device each batch is evaluated on.

    Returns:
        CPU tensor of per-batch outputs concatenated in loader order.
    """
    outputs = []
    # Context manager guarantees the progress bar is closed even if a batch raises.
    with tqdm(total=x_all.size(0)) as pbar:
        pbar.set_description('Evaluating')
        for batch_size, n_id, adj in subgraph_loader:
            edge_index, _, size = adj.to(device)
            x = x_all[n_id].to(device)
            # The first size[1] gathered rows are this batch's target nodes.
            x_target = x[:size[1]]
            # NOTE(review): EVAL is not a keyword of stock PyG
            # DeepGraphInfomax.forward; this assumes a customised model that
            # returns just the encoder output when EVAL=True -- confirm.
            out = model((x, x_target), edge_index, EVAL=True)
            outputs.append(out.detach().cpu())
            pbar.update(batch_size)

    return torch.cat(outputs, dim=0)

def main():
    """Train a one-layer DGI encoder on ogbn-products and evaluate it.

    CLI options:
        --hidden_channels: embedding width.
        --save_name: output file name.
        --input_feature_path: optional ``.npy`` file of replacement node features.
        --epochs: number of training epochs.
        --device: CUDA device index.
        --lr: Adam learning rate.

    After training, all node embeddings are saved to
    ``./DGI_embedding/input_<save_name>.pt`` and the value returned by
    ``test`` is printed as the accuracy.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--hidden_channels', type=int, default=768)
    parser.add_argument('--save_name', type=str, default='OGB_feature')
    parser.add_argument('--input_feature_path', type=str, default='None')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--lr', type=float, default=0.0001)
    args = parser.parse_args()

    SAVEPATH = './DGI_embedding/input_{}.pt'.format(args.save_name)

    dataset = PygNodePropPredDataset(name="ogbn-products", root="../../dataset/")
    split_idx = dataset.get_idx_split()
    data = dataset[0]

    # Optionally replace the OGB node features with externally computed ones.
    if args.input_feature_path != 'None':
        # np.load keeps the file's dtype (often float64); the encoder weights
        # are float32, so cast here to avoid a dtype mismatch in SAGEConv.
        data.x = torch.tensor(np.load(args.input_feature_path), dtype=torch.float32)
        print("Pretrained node features loaded! Path: {}".format(args.input_feature_path))

    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(args.device))
    else:
        # Fall back to CPU instead of failing when the model is moved to CUDA.
        print("CUDA not available; running on CPU.")
        device = torch.device('cpu')

    model = DeepGraphInfomax(
        hidden_channels=args.hidden_channels,
        encoder=Encoder(data.x.size(1), args.hidden_channels),
        summary=lambda z, x, edge_index: torch.sigmoid(z.mean(dim=0)),
        corruption=corruption,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # 30 random-walk subgraphs per epoch, each rooted at 10000 nodes with walk
    # length 1; sample_coverage=0 skips computing normalization statistics.
    loader = GraphSAINTRandomWalkSampler(data,
                                         batch_size=10000,
                                         walk_length=1,
                                         num_steps=30,
                                         sample_coverage=0,
                                         save_dir=dataset.processed_dir)

    # Full one-hop neighbourhoods (sizes=[-1]) for deterministic batched inference.
    subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1],
                                      batch_size=4096, shuffle=False,
                                      num_workers=12)

    for epoch in range(1, args.epochs + 1):
        loss = train(model, optimizer, loader, device)
        print('Epoch: {:03d}, Loss: {:.4f}'.format(epoch, loss))
    acc = test(model, data, subgraph_loader, split_idx, device, SAVEPATH)
    print('Accuracy: {:.4f}'.format(acc))


if __name__ == "__main__":
    main()